Mulgrad
计算逐元素乘法 (Mul) 操作的梯度。该算子是 Mul 算子的反向传播(backward pass)部分。梯度的计算遵循链式法则。
其中 dx0 和 dx1 分别是损失函数对前向输入 Input0 和 Input1 的梯度。
Gradmul1L版本专门用于 x1 张量维度大于或等于 x2 张量的广播场景。Gradmul2l版本专门用于 x2 张量维度大于或等于 x1 张量的广播场景。
- 输入:
dy - 来自后一层的上游梯度张量。
x1 - 前向传播时的第一个输入张量(被除数)。
x2 - 前向传播时的第二个输入张量(除数)。
- params - 参数打包成结构体Parameter:
tile_data0 - 临时工作空间地址。
tile_data1 - 临时工作空间地址。
large_shape - x1 和 x2 中维度较大的张量的形状。
small_shape - x1 和 x2 中维度较小的张量的形状。
out_shape - 输出张量 dx1 和 dx2 的形状。
ndims - 张量的维度数。
dy_size - dy元素个数,需要初始化。
x1_size - x1元素个数,需要初始化。
x2_size - x2元素个数,需要初始化。
large_strides - 维度较大张量的步长信息。
small_strides - 维度较小张量的步长信息。
out_strides - 输出张量的步长信息。
large_multiples - 维度较大张量的广播倍数。
small_multiples - 维度较小张量的广播倍数。
indices - 用于广播计算的临时索引空间地址。
x1_shape - x1的维度信息地址。
x2_shape - x2的维度信息地址。
core_mask - 核掩码。
- 输出:
dx1 - 写入计算出的对 x1 的梯度。
dx2 - 写入计算出的对 x2 的梯度。
- 支持平台:
FT78NEMT7004
备注
FT78NE 支持fp32
MT7004 支持fp16, fp32
共享存储版本:
-
void fp_mul_grad_s(float *dy, float *dx1, float *dx2, float *x1, float *x2, Parameter *params, int core_mask)
-
void hp_mul_grad_s(half *dy, half *dx1, half *dx2, half *x1, half *x2, Parameter *params, int core_mask)
C调用示例:
1//FT78NE示例
2#include <stdio.h>
3#include <divgrad.h>
4int main(int argc, char* argv[]) {
5 float *dy = (float *)0x81000000;//输入,初始化
6 float *dx1 = (float *)0x82000000;//输出,需要初始化
7 float *dx2 = (float *)0x83000000;//输出,需要初始化
8 float *x1_data = (float *)0x84000000;//输入,需要初始化
9 float *x2_data = (float *)0x85000000;//输入,需要初始化
10 float *tile_data0 = (float *)0x86000000;//中间结果,不需要初始化
11 float *tile_data1 = (float *)0x87000000;//中间结果,不需要初始化
12
13 long long ndims = 4;
14 long long dy_size;
15 long long x1_size;
16 long long x2_size;
17
18 int *large_strides = (int *)0x91000000;//不需要初始化
19 int *small_strides = (int *)0x91001000; //不需要初始化
20 int *out_strides = (int *)0x91002000; //不需要初始化
21 int *large_multiples = (int *)0x91003000; //不需要初始化
22 int *small_multiples = (int *)0x91004000; //不需要初始化
23 int *indices = (int *)0x91005000;
24 float *check_dx1 = (float *)0x91006000;
25 float *check_dx2 = (float *)0x1007000;
26
27 int i = 0;
28 srand(seed++);
29
30 //初始化
31 int x1_shape[4] = {8, 1, 8, 8};
32 int x2_shape[4] = {8, 8, 8, 8};
33
34 int* large_shape = (int*) x2_shape;
35 int* small_shape = (int*) x1_shape;
36 int *out_shape = (int*) x2_shape;
37
38 dy_size = out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3];
39 x1_size = x1_shape[0] * x1_shape[1] * x1_shape[2] * x1_shape[3];
40 x2_size = x2_shape[0] * x2_shape[1] * x2_shape[2] * x2_shape[3];
41
42 for(i = 0; i < dy_size; ++i) {
43 dy[i] = (float)(rand()%20)/2;
44 }
45
46 for(i = 0; i < x1_size; ++i) {
47 x1_data[i] = (float)(rand()%20)/2;
48 }
49
50 for(i = 0; i < x2_size; ++i) {
51 x2_data[i] = (float)(rand()%20)/2;
52 }
53
54 memset(indices, 0, ndims*sizeof(int));
55
56 Parameter params;
57 params.tile_data0 = tile_data0;
58 params.tile_data1 = tile_data1;
59 params.large_shape = large_shape;
60 params.small_shape = small_shape;
61 params.out_shape = out_shape;
62 params.ndims = ndims;
63 params.dy_size = dy_size;
64 params.x1_size = x1_size;
65 params.x2_size = x2_size;
66 params.large_strides = large_strides;
67 params.small_strides = small_strides;
68 params.out_strides = out_strides;
69 params.large_multiples = large_multiples;
70 params.small_multiples = small_multiples;
71 params.indices = indices;
72 params.x1_shape = x1_shape;
73 params.x2_shape = x2_shape;
74
75 int core_mask = 0b1111;
76 /*性能统计*/
77 fp_mul_grad_s(dy, dx1, dx2, x1_data, x2_data, ¶ms, core_mask);
78 return 0;
79}
私有存储版本:
-
void fp_mul_grad_p(float *dy, float *dx1, float *dx2, float *x1, float *x2, Parameter *params)
-
void hp_mul_grad_p(half *dy, half *dx1, half *dx2, half *x1, half *x2, Parameter *params)
C调用示例:
1//FT78NE示例
2#include <stdio.h>
3#include <mulgrad.h>
4int main(int argc, char* argv[]) {
5 float *dy = (float *)0x10010000;//输入,初始化
6 float *dx1 = (float *)0x10016000;//输出,需要初始化
7 float *dx2 = (float *)0x10020000;//输出,需要初始化
8 float *x1_data = (float *)0x10026000;//输入,需要初始化
9 float *x2_data = (float *)0x10030000;//输入,需要初始化
10 float *tile_data0 = (float *)0x10036000;//中间结果,不需要初始化
11 float *tile_data1 = (float *)0x10040000;//中间结果,不需要初始化
12
13 long long ndims = 4;
14 long long dy_size;
15 long long x1_size;
16 long long x2_size;
17
18 int *large_strides = (int *)0x10050000;//不需要初始化
19 int *small_strides = (int *)0x10051000; //不需要初始化
20 int *out_strides = (int *)0x10052000; //不需要初始化
21 int *large_multiples = (int *)0x10053000; //不需要初始化
22 int *small_multiples = (int *)0x10054000; //不需要初始化
23 int *indices = (int *)0x10055000;
24 float *check_dx1 = (float *)0x10060000;
25 float *check_dx2 = (float *)0x10070000;
26
27 int i = 0;
28 srand(seed++);
29
30 /*
31 1024时 large_shape = {4, 8, 4, 8}, small_shape = {4, 8, 4, 8}
32 4096时 large_shape = {8, 8, 8, 8}, small_shape = {8, 8, 8, 8}
33 */
34 //初始化
35 int x1_shape[4] = {8, 1, 8, 8};
36 int x2_shape[4] = {8, 8, 8, 8};
37
38 int* large_shape = (int*) x2_shape;
39 int* small_shape = (int*) x1_shape;
40 int *out_shape = (int*) x2_shape;
41
42 dy_size = out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3];
43 x1_size = x1_shape[0] * x1_shape[1] * x1_shape[2] * x1_shape[3];
44 x2_size = x2_shape[0] * x2_shape[1] * x2_shape[2] * x2_shape[3];
45
46 for(i = 0; i < dy_size; ++i) {
47 dy[i] = (float)(rand()%20)/2;
48 }
49
50 for(i = 0; i < x1_size; ++i) {
51 x1_data[i] = (float)(rand()%20)/2;
52 }
53
54 for(i = 0; i < x2_size; ++i) {
55 x2_data[i] = (float)(rand()%20)/2;
56 }
57
58 memset(indices, 0, ndims*sizeof(int));
59
60 Parameter params;
61
62 params.tile_data0 = tile_data0;
63 params.tile_data1 = tile_data1;
64 params.large_shape = large_shape;
65 params.small_shape = small_shape;
66 params.out_shape = out_shape;
67 params.ndims = ndims;
68 params.dy_size = dy_size;
69 params.x1_size = x1_size;
70 params.x2_size = x2_size;
71 params.large_strides = large_strides;
72 params.small_strides = small_strides;
73 params.out_strides = out_strides;
74 params.large_multiples = large_multiples;
75 params.small_multiples = small_multiples;
76 params.indices = indices;
77 params.x1_shape = x1_shape;
78 params.x2_shape = x2_shape;
79
80 fp_mul_grad_p(dy, dx1, dx2, x1_data, x2_data, ¶ms);
81 return 0;
82}